#' This is the companion code to the post 
#' "Image-to-image translation with Pix2Pix: An implementation using Keras and eager execution"
#' on the TensorFlow for R blog.
#' 
#' https://blogs.rstudio.com/tensorflow/posts/2018-09-20-eager-pix2pix

library(keras)
library(tensorflow)
library(tfautograph)

library(tfdatasets)
library(purrr)

# When FALSE the model is trained before the evaluation step at the bottom of
# this script; when TRUE training is skipped and only the latest checkpoint
# is restored.
restore <- FALSE

# Directory expected to contain "train/" and "test/" sub-folders of .jpg files.
data_dir <- "facades"

# Shuffle buffer size for the training pipeline.
buffer_size <- 400
# Number of images per batch.
batch_size <- 1
# Batches per epoch, assuming the training set holds `buffer_size` images.
batches_per_epoch <- buffer_size / batch_size
# Spatial size of the images fed to the networks (crop size for training,
# resize target for testing).
img_width <- 256L
img_height <- 256L

# Column bounds used by load() to split each file into two halves:
# columns 1..w2 and columns (w2 + 1)..w.
w <- 512L
w2 <- 256L

# Read a JPEG file and split it column-wise into two float32 images.
#
# image_file: path of the JPEG to read (passed to tf$io$read_file).
#
# Returns list(input_image, real_image): `real_image` holds columns 1..w2 of
# the decoded image, `input_image` holds columns (w2 + 1)..w. Both are cast to
# float32; pixel values are not rescaled here.
load <- function(image_file) {
  decoded <- tf$image$decode_jpeg(tf$io$read_file(image_file))
  left_half <- k_cast(decoded[ , 1L:w2, ], tf$float32)
  right_half <- k_cast(decoded[ , (w2 + 1L):w, ], tf$float32)
  list(right_half, left_half)
}
  
# Resize both images of a pair to the same size with nearest-neighbour
# interpolation.
#
# input_image, real_image: the two images to resize.
# height, width: target size, forwarded to tf$image$resize as list(height, width).
#
# Returns list(input_image, real_image) with the resized images.
resize <- function(input_image, real_image, height, width) {
  target_size <- list(height, width)
  nearest <- tf$image$ResizeMethod$NEAREST_NEIGHBOR
  list(
    tf$image$resize(input_image, target_size, method = nearest),
    tf$image$resize(real_image, target_size, method = nearest)
  )
}
  
# Take the same random crop from both images of a pair.
#
# The two images are stacked along a new leading axis so that a single call to
# tf$image$random_crop picks one crop window that is applied to both.
#
# Returns list(input_image, real_image), each cropped to
# img_height x img_width x 3.
random_crop <- function(input_image, real_image) {
  pair <- k_stack(list(input_image, real_image), axis = 1)
  crop_shape <- c(2L, img_height, img_width, 3L)
  cropped <- tf$image$random_crop(pair, size = crop_shape)
  list(cropped[1, , , ], cropped[2, , , ])
}

# Linearly rescale both images with x / 127.5 - 1, which maps 0 to -1 and
# 255 to 1.
#
# Returns list(input_image, real_image) with the rescaled images.
normalize <- function(input_image, real_image) {
  rescale <- function(img) (img / 127.5) - 1
  list(rescale(input_image), rescale(real_image))
}

# Random augmentation applied to a training pair: upscale, random crop back to
# the network input size, and mirror both images together with probability
# one half.
#
# Returns list(input_image, real_image).
random_jitter <- function(input_image, real_image) {
  # Enlarge to 286 x 286 so the crop below has room to move.
  resized <- resize(input_image, real_image, 286L, 286L)
  # Same random img_height x img_width window for both images.
  cropped <- random_crop(resized[[1]], resized[[2]])
  input_image <- cropped[[1]]
  real_image <- cropped[[2]]
  # The condition is a tensor; autograph (below) turns this `if` into a graph
  # conditional, so it must stay a plain `if` statement.
  if (tf$random$uniform(shape = list()) > 0.5) {
    input_image <- tf$image$flip_left_right(input_image)
    real_image <- tf$image$flip_left_right(real_image)
  }
  list(input_image, real_image)
}

# Convert the R control flow above and wrap the result as a TensorFlow function.
random_jitter <- tf_function(autograph(random_jitter))

# Load one training example: read and split the file, apply random jitter,
# then rescale pixel values with normalize().
#
# image_file: path of the JPEG to read.
#
# Returns list(input_image, real_image).
load_image_train <- function(image_file) {
  c(input_image, real_image) %<-% load(image_file)
  c(input_image, real_image) %<-% random_jitter(input_image, real_image)
  # Normalisation must match load_image_test(): the generator ends in a tanh
  # and generate_images() maps values back with * 0.5 + 0.5, so both assume
  # inputs and targets in [-1, 1] rather than raw [0, 255] pixels.
  c(input_image, real_image) %<-% normalize(input_image, real_image)
  list(input_image, real_image)
}

# Load one test example: read and split the file, resize both halves to
# img_height x img_width and rescale pixel values with normalize().
# No random augmentation is applied.
#
# image_file: path of the JPEG to read.
#
# Returns list(input_image, real_image).
load_image_test <- function(image_file) {
  loaded <- load(image_file)
  resized <- resize(loaded[[1]], loaded[[2]], img_height, img_width)
  normalize(resized[[1]], resized[[2]])
}

# Training pipeline: list the .jpg files under <data_dir>/train, shuffle them
# with a buffer of `buffer_size`, map each path through load_image_train() and
# group the results into batches of `batch_size`.
train_dataset <-
  tf$data$Dataset$list_files(file.path(data_dir, "train/*.jpg")) %>%
  dataset_shuffle(buffer_size) %>%
  dataset_map(load_image_train, num_parallel_calls = 1) %>%
  dataset_batch(batch_size)

# Test pipeline: same structure over <data_dir>/test, using load_image_test()
# and no explicit shuffle step.
test_dataset <-
  tf$data$Dataset$list_files(file.path(data_dir, "test/*.jpg")) %>%
  dataset_map(load_image_test, num_parallel_calls = 1) %>%
  dataset_batch(batch_size)


# Encoder block of the generator: stride-2 convolution, optional batch
# normalisation, then leaky ReLU.
#
# filters:         number of convolution filters.
# size:            convolution kernel size.
# apply_batchnorm: whether to batch-normalise the convolution output.
# name:            name given to the custom model.
#
# Returns a custom Keras model whose call signature is
# function(x, mask = NULL, training = TRUE).
downsample <- function(filters,
                       size,
                       apply_batchnorm = TRUE,
                       name = "downsample") {
  keras_model_custom(name = name, function(self) {
    self$apply_batchnorm <- apply_batchnorm
    # NOTE(review): stddev is 0.2 here while upsample() uses the initializer's
    # default; the pix2pix paper initialises with stddev 0.02 -- confirm.
    self$conv1 <- layer_conv_2d(
      filters = filters,
      kernel_size = size,
      strides = 2,
      padding = 'same',
      kernel_initializer = initializer_random_normal(0, 0.2),
      use_bias = FALSE
    )
    if (self$apply_batchnorm) {
      self$batchnorm <- layer_batch_normalization()
    }

    function(x,
             mask = NULL,
             training = TRUE) {
      x <- self$conv1(x)
      if (self$apply_batchnorm) {
        # The pipe returns a new tensor and does not modify `x` in place, so
        # the result has to be assigned or the normalisation is lost.
        x <- x %>% self$batchnorm(training = training)
      }
      #cat("downsample (generator) output: ", x$shape$as_list(), "\n")
      x %>% layer_activation_leaky_relu()
    }

  })
}

# Decoder block of the generator: stride-2 transposed convolution, batch
# normalisation, optional dropout, ReLU, then concatenation with a skip
# connection.
#
# filters:       number of transposed-convolution filters.
# size:          kernel size.
# apply_dropout: whether to apply dropout (rate 0.5) after batch normalisation.
# name:          name given to the custom model.
#
# Returns a custom Keras model whose call signature is
# function(xs, mask = NULL, training = TRUE), where `xs` is
# list(tensor_to_upsample, skip_tensor).
upsample <- function(filters,
                     size,
                     apply_dropout = FALSE,
                     name = "upsample") {
  # Forward `name` (it was previously ignored), matching downsample().
  keras_model_custom(name = name, function(self) {
    self$apply_dropout <- apply_dropout
    self$up_conv <- layer_conv_2d_transpose(
      filters = filters,
      kernel_size = size,
      strides = 2,
      padding = "same",
      kernel_initializer = initializer_random_normal(),
      use_bias = FALSE
    )
    self$batchnorm <- layer_batch_normalization()
    if (self$apply_dropout) {
      self$dropout <- layer_dropout(rate = 0.5)
    }
    function(xs,
             mask = NULL,
             training = TRUE) {
      c(x1, x2) %<-% xs
      x <- self$up_conv(x1) %>% self$batchnorm(training = training)
      # Each step returns a new tensor; the results must be assigned back to
      # `x`, otherwise dropout and the ReLU are computed and then thrown away.
      if (self$apply_dropout) {
        x <- x %>% self$dropout(training = training)
      }
      x <- x %>% layer_activation("relu")
      # Skip connection: append the matching encoder output.
      concat <- k_concatenate(list(x, x2))
      #cat("upsample (generator) output: ", concat$shape$as_list(), "\n")
      concat
    }
  })
}

# U-Net style generator: eight downsample() blocks, seven upsample() blocks
# that each receive the matching encoder output as a skip connection, and a
# final transposed convolution with tanh activation producing 3 channels.
#
# name: name given to the custom model.
#
# Returns a custom Keras model whose call signature is
# function(x, mask = NULL, training = TRUE).
generator <- function(name = "generator") {
  keras_model_custom(name = name, function(self) {
    # Encoder. The first block skips batch normalisation.
    self$down1 <- downsample(64, 4, apply_batchnorm = FALSE)
    self$down2 <- downsample(128, 4)
    self$down3 <- downsample(256, 4)
    self$down4 <- downsample(512, 4)
    self$down5 <- downsample(512, 4)
    self$down6 <- downsample(512, 4)
    self$down7 <- downsample(512, 4)
    self$down8 <- downsample(512, 4)

    # Decoder. Only the first three blocks use dropout.
    self$up1 <- upsample(512, 4, apply_dropout = TRUE)
    self$up2 <- upsample(512, 4, apply_dropout = TRUE)
    self$up3 <- upsample(512, 4, apply_dropout = TRUE)
    self$up4 <- upsample(512, 4)
    self$up5 <- upsample(256, 4)
    self$up6 <- upsample(128, 4)
    self$up7 <- upsample(64, 4)
    self$last <- layer_conv_2d_transpose(
      filters = 3,
      kernel_size = 4,
      strides = 2,
      padding = "same",
      kernel_initializer = initializer_random_normal(0, 0.2),
      activation = "tanh"
    )

    function(x,
             mask = NULL,
             training = TRUE) {
      # Shapes below assume x is (bs, 256, 256, 3).
      d1 <- self$down1(x, training = training)  # (bs, 128, 128, 64)
      d2 <- self$down2(d1, training = training) # (bs, 64, 64, 128)
      d3 <- self$down3(d2, training = training) # (bs, 32, 32, 256)
      d4 <- self$down4(d3, training = training) # (bs, 16, 16, 512)
      d5 <- self$down5(d4, training = training) # (bs, 8, 8, 512)
      d6 <- self$down6(d5, training = training) # (bs, 4, 4, 512)
      d7 <- self$down7(d6, training = training) # (bs, 2, 2, 512)
      d8 <- self$down8(d7, training = training) # (bs, 1, 1, 512)

      # Each decoder step is paired with the encoder output of equal size.
      u1 <- self$up1(list(d8, d7), training = training) # (bs, 2, 2, 1024)
      u2 <- self$up2(list(u1, d6), training = training) # (bs, 4, 4, 1024)
      u3 <- self$up3(list(u2, d5), training = training) # (bs, 8, 8, 1024)
      u4 <- self$up4(list(u3, d4), training = training) # (bs, 16, 16, 1024)
      u5 <- self$up5(list(u4, d3), training = training) # (bs, 32, 32, 512)
      u6 <- self$up6(list(u5, d2), training = training) # (bs, 64, 64, 256)
      u7 <- self$up7(list(u6, d1), training = training) # (bs, 128, 128, 128)

      #cat("generator output: ", x16$shape$as_list(), "\n")
      self$last(u7) # (bs, 256, 256, 3)
    }
  })
}


# Downsampling block of the discriminator: stride-2 convolution, optional
# batch normalisation, then leaky ReLU.
#
# filters:         number of convolution filters.
# size:            convolution kernel size.
# apply_batchnorm: whether to batch-normalise the convolution output.
# name:            name given to the custom model.
#
# Returns a custom Keras model whose call signature is
# function(x, mask = NULL, training = TRUE).
disc_downsample <- function(filters,
                            size,
                            apply_batchnorm = TRUE,
                            name = "disc_downsample") {
  keras_model_custom(name = name, function(self) {
    self$apply_batchnorm <- apply_batchnorm
    self$conv1 <- layer_conv_2d(
      filters = filters,
      kernel_size = size,
      strides = 2,
      padding = 'same',
      kernel_initializer = initializer_random_normal(0, 0.2),
      use_bias = FALSE
    )
    if (self$apply_batchnorm) {
      self$batchnorm <- layer_batch_normalization()
    }

    function(x,
             mask = NULL,
             training = TRUE) {
      x <- self$conv1(x)
      if (self$apply_batchnorm) {
        # The pipe returns a new tensor and does not modify `x` in place, so
        # the result has to be assigned or the normalisation is lost.
        x <- x %>% self$batchnorm(training = training)
      }
      x %>% layer_activation_leaky_relu()
    }

  })
}

# Patch discriminator: concatenates the conditioning image with a candidate
# image, applies three disc_downsample() blocks, then two
# zero-pad + stride-1 convolution stages, ending in a single-channel map of
# raw scores (no activation on the last layer).
#
# name: name given to the custom model.
#
# Returns a custom Keras model whose call signature is
# function(x, y, mask = NULL, training = TRUE).
discriminator <- function(name = "discriminator") {
  keras_model_custom(name = name, function(self) {
    self$down1 <- disc_downsample(64, 4, FALSE)
    self$down2 <- disc_downsample(128, 4)
    self$down3 <- disc_downsample(256, 4)
    # Padding by 1 on each side before an unpadded 4x4 convolution takes the
    # feature map from (batch_size, 32, 32, 256) to (batch_size, 31, 31, 512).
    self$zero_pad1 <- layer_zero_padding_2d()
    self$conv <- layer_conv_2d(
      filters = 512,
      kernel_size = 4,
      strides = 1,
      kernel_initializer = initializer_random_normal(),
      use_bias = FALSE
    )
    self$batchnorm <- layer_batch_normalization()
    self$zero_pad2 <- layer_zero_padding_2d()
    self$last <- layer_conv_2d(
      filters = 1,
      kernel_size = 4,
      strides = 1,
      kernel_initializer = initializer_random_normal()
    )

    function(x,
             y,
             mask = NULL,
             training = TRUE) {
      h <- k_concatenate(list(x, y))              # (bs, 256, 256, channels*2)
      h <- self$down1(h, training = training)     # (bs, 128, 128, 64)
      h <- self$down2(h, training = training)     # (bs, 64, 64, 128)
      h <- self$down3(h, training = training)     # (bs, 32, 32, 256)
      h <- self$zero_pad1(h)                      # (bs, 34, 34, 256)
      h <- self$conv(h)                           # (bs, 31, 31, 512)
      h <- self$batchnorm(h, training = training)
      h <- layer_activation_leaky_relu(h)
      h <- self$zero_pad2(h)                      # (bs, 33, 33, 512)
      #cat("discriminator output: ", x$shape$as_list(), "\n")
      self$last(h)                                # (bs, 30, 30, 1)
    }
  })

}

# Build the two networks. From here on `generator` and `discriminator` refer
# to the model instances, no longer to the constructor functions.
generator <- generator()
discriminator <- discriminator()

# Binary cross-entropy on raw scores (from_logits = TRUE), shared by both
# loss functions.
cross_entropy <- tf$keras$losses$BinaryCrossentropy(from_logits = TRUE)

# Discriminator objective: cross-entropy of the scores for real pairs against
# a target of all ones, plus cross-entropy of the scores for generated pairs
# against a target of all zeros.
#
# real_output:      discriminator scores for (input, real target) pairs.
# generated_output: discriminator scores for (input, generated image) pairs.
#
# Returns the sum of the two cross-entropy terms.
discriminator_loss <- function(real_output, generated_output) {
  loss_on_real <- cross_entropy(k_ones_like(real_output), real_output)
  loss_on_fake <- cross_entropy(
    k_zeros_like(generated_output),
    generated_output
  )
  loss_on_real + loss_on_fake
}

# Weight of the L1 reconstruction term relative to the adversarial term.
lambda <- 100

# Generator objective: cross-entropy of the discriminator's scores for the
# generated image against a target of all ones, plus `lambda` times the mean
# absolute difference between the target and the generated image.
#
# disc_judgment:    discriminator scores for the generated image.
# generated_output: image produced by the generator.
# target:           ground-truth image.
#
# Returns the combined loss.
generator_loss <- function(disc_judgment, generated_output, target) {
  adversarial_term <- cross_entropy(tf$ones_like(disc_judgment), disc_judgment)
  mean_abs_error <- tf$reduce_mean(tf$abs(target - generated_output))
  adversarial_term + (lambda * mean_abs_error)
}

# One Adam optimizer per network, both with learning rate 2e-4 and
# beta_1 = 0.5.
discriminator_optimizer <- tf$optimizers$Adam(2e-4, beta_1 = 0.5)
generator_optimizer <- tf$optimizers$Adam(2e-4, beta_1 = 0.5)

# Checkpoints are written with the prefix "<checkpoint_dir>/ckpt" and track
# both models together with their optimizers.
checkpoint_dir <- "./checkpoints_pix2pix"
checkpoint_prefix <- file.path(checkpoint_dir, "ckpt")
checkpoint <-
  tf$train$Checkpoint(
    generator_optimizer = generator_optimizer,
    discriminator_optimizer = discriminator_optimizer,
    generator = generator,
    discriminator = discriminator
  )

# Write a PNG named "pix2pix_<id>.png" showing, side by side, the first
# image of the input batch, of the target batch and of the generator's
# prediction.
#
# generator: model called as generator(input, training = TRUE).
# input:     batch of input images.
# target:    batch of ground-truth images.
# id:        string appended to the file name.
#
# Called for its side effect (the PNG file); the return value is not
# meaningful.
generate_images <- function(generator, input, target, id) {
  # Take the first image of a batch, undo the x / 127.5 - 1 style scaling
  # (x * 0.5 + 0.5) and clamp to [0, 1], which as.raster() requires.
  to_unit_range <- function(batch) {
    img <- batch[1, , , ]$numpy() * 0.5 + 0.5
    img[img > 1] <- 1
    img[img < 0] <- 0
    img
  }

  # NOTE(review): training = TRUE at inference time looks deliberate (keeps
  # batch statistics / dropout active as during training) -- confirm.
  prediction <- generator(input, training = TRUE)

  png(paste0("pix2pix_", id, ".png"), width = 900, height = 300)
  # Close the device even if plotting fails, so no graphics device (and no
  # half-written file handle) is left open.
  on.exit(dev.off(), add = TRUE)

  par(mfcol = c(1, 3))
  par(mar = c(0, 0, 0, 0),
      xaxs = 'i',
      yaxs = 'i')
  # NOTE(review): `main` is passed to as.raster(), not plot(), so no title
  # appears to be drawn -- confirm whether titles were intended.
  plot(as.raster(to_unit_range(input), main = "input image"))
  plot(as.raster(to_unit_range(target), main = "ground truth"))
  plot(as.raster(to_unit_range(prediction), main = "generated"))
  invisible(NULL)
}

# Adversarial training loop.
#
# dataset:    dataset yielding batches of list(input_image, target); a fresh
#             one-shot iterator is created over it for every epoch.
# num_epochs: number of passes over `dataset`.
#
# For every batch the generator and the discriminator are each updated once
# with their own optimizer. After each epoch the mean generator and
# discriminator losses are printed. Every 10th epoch a sample image is written
# with generate_images() (first batch of the global `test_dataset`) and a
# checkpoint is saved.
train <- function(dataset, num_epochs) {
  # seq_len() yields an empty loop for num_epochs = 0 (1:0 would run twice).
  for (epoch in seq_len(num_epochs)) {
    total_loss_gen <- 0
    total_loss_disc <- 0
    n_batches <- 0L
    # Iterate over the dataset that was passed in, not the global one.
    iter <- make_iterator_one_shot(dataset)

    until_out_of_range({
      batch <- iterator_get_next(iter)
      input_image <- batch[[1]]
      target <- batch[[2]]

      # Two tapes because each network needs its own gradient computation
      # from the same forward pass.
      with(tf$GradientTape() %as% gen_tape, {
        with(tf$GradientTape() %as% disc_tape, {
          gen_output <- generator(input_image, training = TRUE)
          disc_real_output <-
            discriminator(input_image, target, training = TRUE)
          disc_generated_output <-
            discriminator(input_image, gen_output, training = TRUE)
          gen_loss <-
            generator_loss(disc_generated_output, gen_output, target)
          disc_loss <-
            discriminator_loss(disc_real_output, disc_generated_output)
        })
      })
      generator_gradients <- gen_tape$gradient(gen_loss,
                                               generator$trainable_variables)
      discriminator_gradients <- disc_tape$gradient(disc_loss,
                                                    discriminator$trainable_variables)

      # transpose() turns list(gradients, variables) into a list of
      # (gradient, variable) pairs, the form apply_gradients() expects.
      generator_optimizer$apply_gradients(transpose(list(
        generator_gradients,
        generator$trainable_variables
      )))
      discriminator_optimizer$apply_gradients(transpose(
        list(discriminator_gradients,
             discriminator$trainable_variables)
      ))

      # Bookkeeping only; kept outside the tapes so it is not recorded.
      total_loss_gen <- total_loss_gen + gen_loss
      total_loss_disc <- total_loss_disc + disc_loss
      n_batches <- n_batches + 1L
    })

    cat("Epoch ", epoch, "\n")
    if (n_batches > 0L) {
      # Average over the batches actually seen rather than an assumed count.
      cat("Generator loss: ",
          total_loss_gen$numpy() / n_batches,
          "\n")
      cat("Discriminator loss: ",
          total_loss_disc$numpy() / n_batches,
          "\n\n")
    } else {
      warning("`dataset` yielded no batches; nothing was trained this epoch.",
              call. = FALSE)
    }

    if (epoch %% 10 == 0) {
      test_iter <- make_iterator_one_shot(test_dataset)
      batch <- iterator_get_next(test_iter)
      input <- batch[[1]]
      target <- batch[[2]]
      # Label the sample with the current epoch (there is no `i` in scope).
      generate_images(generator, input, target, paste0("epoch_", epoch))
      checkpoint$save(file_prefix = checkpoint_prefix)
    }

  }
}

# Train for 200 epochs unless `restore` asks to reuse existing checkpoints.
if (!restore) {
  train(train_dataset, 200)
} 


# Load the most recent checkpoint found in `checkpoint_dir` into the models
# and optimizers.
checkpoint$restore(tf$train$latest_checkpoint(checkpoint_dir))

# Write one comparison image per test batch, named "pix2pix_test_<i>.png",
# until the test iterator is exhausted.
test_iter <- make_iterator_one_shot(test_dataset)
i <- 1
until_out_of_range({
  batch <- iterator_get_next(test_iter)
  input <- batch[[1]]
  target <- batch[[2]]
  generate_images(generator, input, target, paste0("test_", i))
  i <- i + 1
})



